{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "cqCt_GhvCnwY"
      },
      "source": [
        "# VQ-VAE training example\n",
        "\n",
        "Demonstration of how to train the model specified in https://arxiv.org/abs/1711.00937\n",
        "\n",
        "On Mac and Linux, simply execute each cell in turn."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "95YuC82P35Of"
      },
      "outputs": [],
      "source": [
        "from __future__ import print_function\n",
        "\n",
        "import os\n",
        "import subprocess\n",
        "import tempfile\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import sonnet as snt\n",
        "import tensorflow as tf\n",
        "import tarfile\n",
        "\n",
        "from six.moves import cPickle\n",
        "from six.moves import urllib\n",
        "from six.moves import xrange"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "DT8fKmqQC35h"
      },
      "source": [
        "# Download Cifar10 data\n",
        "This requires a connection to the internet and will download ~160MB.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 51
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 22057,
          "status": "ok",
          "timestamp": 1527689581039,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": -60
        },
        "id": "mR0lkHXDC3Pz",
        "outputId": "762b7d4f-39a9-4db5-91b8-808246cbfb1b"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "using temporary directory /tmp/tmpm_QQkc\n",
            "extracted data files to /tmp/tmpm_QQkc\n"
          ]
        }
      ],
      "source": [
        "data_path = \"https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\"\n",
        "\n",
        "local_data_dir = tempfile.mkdtemp()  # Change this as needed\n",
        "tf.gfile.MakeDirs(local_data_dir)\n",
        "\n",
        "# Stream the .tar.gz straight from the HTTP response into the extraction\n",
        "# directory. The nested try/finally blocks make sure the tar reader and the\n",
        "# HTTP response are both closed even when the download or extraction fails;\n",
        "# the archive is closed first because it reads from the response.\n",
        "# NOTE: extractall() trusts the member paths stored in the archive, so only\n",
        "# point data_path at a source you trust.\n",
        "url = urllib.request.urlopen(data_path)\n",
        "try:\n",
        "  archive = tarfile.open(fileobj=url, mode='r|gz')  # read a .tar.gz stream\n",
        "  try:\n",
        "    archive.extractall(local_data_dir)\n",
        "  finally:\n",
        "    archive.close()\n",
        "finally:\n",
        "  url.close()\n",
        "print('extracted data files to %s' % local_data_dir)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "lUgvEhfJyQLZ"
      },
      "source": [
        "# Load the data into Numpy\n",
        "We compute the variance of the whole training set to normalise the Mean Squared Error below.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "9C-V2D6RSQwl"
      },
      "outputs": [],
      "source": [
        "def unpickle(filename):\n",
        "  \"\"\"Load the pickled object stored in `filename` and return it.\n",
        "\n",
        "  The `encoding` argument only exists in Python 3's pickle module, so it is\n",
        "  passed only there; Python 2's cPickle.load() accepts no keyword arguments.\n",
        "\n",
        "  NOTE: unpickling can execute arbitrary code, so only call this on files\n",
        "  that come from a trusted source.\n",
        "  \"\"\"\n",
        "  import sys  # local import: only needed for the interpreter version check\n",
        "  with open(filename, 'rb') as fo:\n",
        "    if sys.version_info[0] \u003e= 3:\n",
        "      # 'latin1' lets Python 3 decode the 8-bit strings of Python 2 pickles.\n",
        "      return cPickle.load(fo, encoding='latin1')\n",
        "    return cPickle.load(fo)\n",
        "  \n",
        "def reshape_flattened_image_batch(flat_image_batch):\n",
        "  \"\"\"Turn rows of flattened 3x32x32 images into a channels-last image array.\"\"\"\n",
        "  channels_first = flat_image_batch.reshape(-1, 3, 32, 32)  # NCHW layout\n",
        "  return channels_first.transpose([0, 2, 3, 1])  # reorder NCHW to NHWC\n",
        "\n",
        "def combine_batches(batch_list):\n",
        "  \"\"\"Merge a list of batch dicts into a single images/labels dict.\n",
        "\n",
        "  Args:\n",
        "    batch_list: list of dicts, each with a 'data' entry (flattened images)\n",
        "      and a 'labels' entry (one label per image).\n",
        "\n",
        "  Returns:\n",
        "    Dict with 'images' (all batches stacked along the first axis, in NHWC\n",
        "    layout) and 'labels' (a column of shape [num_images, 1], same order).\n",
        "  \"\"\"\n",
        "  images = np.vstack([reshape_flattened_image_batch(batch['data'])\n",
        "                      for batch in batch_list])\n",
        "  # Concatenate rather than vstack the 1-D label lists: vstack requires every\n",
        "  # batch to hold the same number of labels, concatenate does not.\n",
        "  labels = np.concatenate(\n",
        "      [np.array(batch['labels']) for batch in batch_list]).reshape(-1, 1)\n",
        "  return {'images': images, 'labels': labels}\n",
        "  \n",
        "\n",
        "# Training set: data_batch_1 .. data_batch_4 (range(1,5) stops before 5).\n",
        "train_data_dict = combine_batches([\n",
        "    unpickle(os.path.join(local_data_dir,\n",
        "                          'cifar-10-batches-py/data_batch_%d' % i))\n",
        "    for i in range(1,5)\n",
        "])\n",
        "\n",
        "# Validation set: data_batch_5, which is not part of the training set above.\n",
        "valid_data_dict = combine_batches([\n",
        "    unpickle(os.path.join(local_data_dir,\n",
        "                          'cifar-10-batches-py/data_batch_5'))])\n",
        "\n",
        "# Test set: the separate test_batch file.\n",
        "test_data_dict = combine_batches([\n",
        "    unpickle(os.path.join(local_data_dir, 'cifar-10-batches-py/test_batch'))])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "cIRl2ZtxoKNz"
      },
      "outputs": [],
      "source": [
        "def cast_and_normalise_images(data_dict):\n",
        "  \"\"\"Convert images to floating point with the range [-0.5, 0.5].\n",
        "\n",
        "  Casts data_dict['images'] to float32, divides by 255 and subtracts 0.5, so\n",
        "  pixel values in [0, 255] end up in [-0.5, 0.5]. The entry is replaced in\n",
        "  the dict that was passed in, and that same dict is returned.\n",
        "  \"\"\"\n",
        "  images = data_dict['images']\n",
        "  data_dict['images'] = (tf.cast(images, tf.float32) / 255.0) - 0.5\n",
        "  return data_dict\n",
        "\n",
        "# Variance of the training pixels after scaling by 1/255. The constant -0.5\n",
        "# shift applied in cast_and_normalise_images does not change the variance.\n",
        "data_variance = np.var(train_data_dict['images'] / 255.0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Jse__pEBAkvI"
      },
      "source": [
        "# Encoder \u0026 Decoder Architecture\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "1gwD36Vr6KqA"
      },
      "outputs": [],
      "source": [
        "def residual_stack(h, num_hiddens, num_residual_layers, num_residual_hiddens):\n",
        "  \"\"\"Apply `num_residual_layers` residual blocks to `h`, then a final ReLU.\n",
        "\n",
        "  Each block computes ReLU, a 3x3 convolution with `num_residual_hiddens`\n",
        "  channels, ReLU and a 1x1 convolution with `num_hiddens` channels, and adds\n",
        "  the result back onto `h`.\n",
        "  \"\"\"\n",
        "  for block_idx in range(num_residual_layers):\n",
        "    conv3x3_out = snt.Conv2D(\n",
        "        output_channels=num_residual_hiddens,\n",
        "        kernel_shape=(3, 3),\n",
        "        stride=(1, 1),\n",
        "        name=\"res3x3_%d\" % block_idx)(tf.nn.relu(h))\n",
        "\n",
        "    residual = snt.Conv2D(\n",
        "        output_channels=num_hiddens,\n",
        "        kernel_shape=(1, 1),\n",
        "        stride=(1, 1),\n",
        "        name=\"res1x1_%d\" % block_idx)(tf.nn.relu(conv3x3_out))\n",
        "\n",
        "    # Skip connection: add the block output onto its input.\n",
        "    h = h + residual\n",
        "  return tf.nn.relu(h)\n",
        "\n",
        "class Encoder(snt.AbstractModule):\n",
        "  \"\"\"Encoder: two stride-2 4x4 convolutions, a 3x3 convolution and a residual stack.\n",
        "\n",
        "  Args:\n",
        "    num_hiddens: number of output channels of the enc_2 and enc_3\n",
        "      convolutions and of the residual stack; enc_1 uses half as many.\n",
        "    num_residual_layers: number of blocks in the residual stack.\n",
        "    num_residual_hiddens: number of channels of the 3x3 convolution inside\n",
        "      each residual block.\n",
        "    name: name of the module.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,\n",
        "               name='encoder'):\n",
        "    super(Encoder, self).__init__(name=name)\n",
        "    self._num_hiddens = num_hiddens\n",
        "    self._num_residual_layers = num_residual_layers\n",
        "    self._num_residual_hiddens = num_residual_hiddens\n",
        "\n",
        "  def _build(self, x):\n",
        "    \"\"\"Connect the encoder to the input `x` and return the encoded features.\"\"\"\n",
        "    # The channel count must be an integer. With true division (Python 3, or\n",
        "    # `from __future__ import division`) a bare `/ 2` yields a float, so\n",
        "    # convert explicitly, using the same expression as the Decoder.\n",
        "    h = snt.Conv2D(\n",
        "        output_channels=int(self._num_hiddens / 2),\n",
        "        kernel_shape=(4, 4),\n",
        "        stride=(2, 2),\n",
        "        name=\"enc_1\")(x)\n",
        "    h = tf.nn.relu(h)\n",
        "\n",
        "    h = snt.Conv2D(\n",
        "        output_channels=self._num_hiddens,\n",
        "        kernel_shape=(4, 4),\n",
        "        stride=(2, 2),\n",
        "        name=\"enc_2\")(h)\n",
        "    h = tf.nn.relu(h)\n",
        "\n",
        "    # No ReLU here: residual_stack applies one to its input itself.\n",
        "    h = snt.Conv2D(\n",
        "        output_channels=self._num_hiddens,\n",
        "        kernel_shape=(3, 3),\n",
        "        stride=(1, 1),\n",
        "        name=\"enc_3\")(h)\n",
        "\n",
        "    h = residual_stack(\n",
        "        h,\n",
        "        self._num_hiddens,\n",
        "        self._num_residual_layers,\n",
        "        self._num_residual_hiddens)\n",
        "    return h\n",
        "\n",
        "class Decoder(snt.AbstractModule):\n",
        "  \"\"\"Decoder: a 3x3 convolution, a residual stack and two stride-2 transposed\n",
        "  convolutions, the last of which outputs 3 channels.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,\n",
        "               name='decoder'):\n",
        "    super(Decoder, self).__init__(name=name)\n",
        "    self._num_residual_hiddens = num_residual_hiddens\n",
        "    self._num_residual_layers = num_residual_layers\n",
        "    self._num_hiddens = num_hiddens\n",
        "\n",
        "  def _build(self, x):\n",
        "    \"\"\"Connect the decoder to the input `x` and return the reconstruction.\"\"\"\n",
        "    half_hiddens = int(self._num_hiddens / 2)\n",
        "\n",
        "    net = snt.Conv2D(\n",
        "        output_channels=self._num_hiddens,\n",
        "        kernel_shape=(3, 3),\n",
        "        stride=(1, 1),\n",
        "        name=\"dec_1\")(x)\n",
        "\n",
        "    net = residual_stack(\n",
        "        net,\n",
        "        self._num_hiddens,\n",
        "        self._num_residual_layers,\n",
        "        self._num_residual_hiddens)\n",
        "\n",
        "    # First stride-2 transposed convolution, followed by a ReLU.\n",
        "    net = tf.nn.relu(\n",
        "        snt.Conv2DTranspose(\n",
        "            output_channels=half_hiddens,\n",
        "            output_shape=None,\n",
        "            kernel_shape=(4, 4),\n",
        "            stride=(2, 2),\n",
        "            name=\"dec_2\")(net))\n",
        "\n",
        "    # Second stride-2 transposed convolution maps to 3 output channels.\n",
        "    return snt.Conv2DTranspose(\n",
        "        output_channels=3,\n",
        "        output_shape=None,\n",
        "        kernel_shape=(4, 4),\n",
        "        stride=(2, 2),\n",
        "        name=\"dec_3\")(net)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "FF7WaOn-s7En"
      },
      "source": [
        "# Build Graph and train"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "owGEoOkO4ttk"
      },
      "outputs": [],
      "source": [
        "# Clear the default graph before building the model in this cell.\n",
        "tf.reset_default_graph()\n",
        "\n",
        "# Set hyper-parameters.\n",
        "batch_size = 32\n",
        "image_size = 32\n",
        "\n",
        "# 100k steps should take \u003c 30 minutes on a modern (\u003e= 2017) GPU;\n",
        "# the value set below runs 50k steps.\n",
        "num_training_updates = 50000\n",
        "\n",
        "num_hiddens = 128\n",
        "num_residual_hiddens = 32\n",
        "num_residual_layers = 2\n",
        "# These hyper-parameters define the size of the model (number of parameters and layers).\n",
        "# The hyper-parameters in the paper were (For ImageNet):\n",
        "# batch_size = 128\n",
        "# image_size = 128\n",
        "# num_hiddens = 128\n",
        "# num_residual_hiddens = 32\n",
        "# num_residual_layers = 2\n",
        "\n",
        "# This value is not that important, usually 64 works.\n",
        "# This will not change the capacity in the information-bottleneck.\n",
        "embedding_dim = 64\n",
        "\n",
        "# The higher this value, the higher the capacity in the information bottleneck.\n",
        "num_embeddings = 512\n",
        "\n",
        "# commitment_cost should be set appropriately. It's often useful to try a couple\n",
        "# of values. It mostly depends on the scale of the reconstruction cost\n",
        "# (log p(x|z)). So if the reconstruction cost is 100x higher, the\n",
        "# commitment_cost should also be multiplied by the same amount.\n",
        "commitment_cost = 0.25\n",
        "\n",
        "# Use EMA updates for the codebook (instead of the Adam optimizer).\n",
        "# This typically converges faster, and makes the model less dependent on choice\n",
        "# of the optimizer. In the VQ-VAE paper EMA updates were not used (they were\n",
        "# developed afterwards). See Appendix of the paper for more details.\n",
        "vq_use_ema = False\n",
        "\n",
        "# This is only used for EMA updates.\n",
        "decay = 0.99\n",
        "\n",
        "learning_rate = 3e-4\n",
        "\n",
        "\n",
        "# Data Loading.\n",
        "# Training pipeline: normalise, shuffle with a 10000-element buffer, repeat\n",
        "# and batch, consumed through a one-shot iterator.\n",
        "train_dataset_iterator = (\n",
        "    tf.data.Dataset.from_tensor_slices(train_data_dict)\n",
        "    .map(cast_and_normalise_images)\n",
        "    .shuffle(10000)\n",
        "    .repeat(-1)  # repeat indefinitely\n",
        "    .batch(batch_size)).make_one_shot_iterator()\n",
        "# Validation pipeline: a single unshuffled pass. This is an initializable\n",
        "# iterator, so its `initializer` op has to be run before pulling batches.\n",
        "valid_dataset_iterator = (\n",
        "    tf.data.Dataset.from_tensor_slices(valid_data_dict)\n",
        "    .map(cast_and_normalise_images)\n",
        "    .repeat(1)  # 1 epoch\n",
        "    .batch(batch_size)).make_initializable_iterator()\n",
        "# Tensors (dicts with 'images' and 'labels') yielding the next batch when run.\n",
        "train_dataset_batch = train_dataset_iterator.get_next()\n",
        "valid_dataset_batch = valid_dataset_iterator.get_next()\n",
        "\n",
        "def get_images(sess, subset='train'):\n",
        "  \"\"\"Fetch the images of the next batch from the requested dataset.\n",
        "\n",
        "  Args:\n",
        "    sess: session used to evaluate the dataset batch tensor.\n",
        "    subset: either 'train' or 'valid'.\n",
        "\n",
        "  Returns:\n",
        "    The 'images' entry of the next batch of the chosen subset.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if `subset` is neither 'train' nor 'valid'.\n",
        "  \"\"\"\n",
        "  if subset == 'train':\n",
        "    return sess.run(train_dataset_batch)['images']\n",
        "  elif subset == 'valid':\n",
        "    return sess.run(valid_dataset_batch)['images']\n",
        "  # Fail loudly instead of implicitly returning None, which would only show\n",
        "  # up later as a confusing error when the result is fed to the graph.\n",
        "  raise ValueError(\n",
        "      \"subset must be 'train' or 'valid', got %r\" % (subset,))\n",
        "\n",
        "\n",
        "# Build modules.\n",
        "encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)\n",
        "pre_vq_conv1 = snt.Conv2D(output_channels=embedding_dim,\n",
        "    kernel_shape=(1, 1),\n",
        "    stride=(1, 1),\n",
        "    name=\"to_vq\")\n",
        "\n",
        "# Pick the quantizer variant; `decay` is only passed to the EMA version.\n",
        "if vq_use_ema:\n",
        "  vq_vae = snt.nets.VectorQuantizerEMA(\n",
        "      embedding_dim=embedding_dim,\n",
        "      num_embeddings=num_embeddings,\n",
        "      commitment_cost=commitment_cost,\n",
        "      decay=decay)\n",
        "else:\n",
        "  vq_vae = snt.nets.VectorQuantizer(\n",
        "      embedding_dim=embedding_dim,\n",
        "      num_embeddings=num_embeddings,\n",
        "      commitment_cost=commitment_cost)\n",
        "\n",
        "# Process inputs with conv stack, finishing with 1x1 to get to correct size.\n",
        "x = tf.placeholder(tf.float32, shape=(None, image_size, image_size, 3))\n",
        "z = pre_vq_conv1(encoder(x))\n",
        "\n",
        "# vq_output_train[\"quantize\"] are the quantized outputs of the encoder.\n",
        "# That is also what is used during training with the straight-through estimator.\n",
        "# To get the one-hot coded assignments use vq_output_train[\"encodings\"] instead.\n",
        "# These encodings will not pass gradients into the encoder,\n",
        "# but can be used to train a PixelCNN on top afterwards.\n",
        "\n",
        "# For training\n",
        "vq_output_train = vq_vae(z, is_training=True)\n",
        "x_recon = decoder(vq_output_train[\"quantize\"])\n",
        "recon_error = tf.reduce_mean((x_recon - x)**2) / data_variance  # Normalized MSE\n",
        "loss = recon_error + vq_output_train[\"loss\"]\n",
        "\n",
        "# For evaluation, make sure is_training=False!\n",
        "# This calls the same vq_vae and decoder module objects as the training path.\n",
        "vq_output_eval = vq_vae(z, is_training=False)\n",
        "x_recon_eval = decoder(vq_output_eval[\"quantize\"])\n",
        "\n",
        "# The following is a useful value to track during training.\n",
        "# It indicates how many codes are 'active' on average.\n",
        "perplexity = vq_output_train[\"perplexity\"] \n",
        "\n",
        "# Create optimizer and TF session.\n",
        "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "train_op = optimizer.minimize(loss)\n",
        "sess = tf.train.SingularMonitoredSession()\n",
        "\n",
        "# Train.\n",
        "# Per-step metrics, kept for the progress report below and for plotting.\n",
        "train_res_recon_error = []\n",
        "train_res_perplexity = []\n",
        "for i in xrange(num_training_updates):\n",
        "  # Pull a training batch, then run one optimiser step on it.\n",
        "  feed_dict = {x: get_images(sess)}\n",
        "  results = sess.run([train_op, recon_error, perplexity],\n",
        "                     feed_dict=feed_dict)\n",
        "  # results[0] belongs to train_op; [1] and [2] are the tracked metrics.\n",
        "  train_res_recon_error.append(results[1])\n",
        "  train_res_perplexity.append(results[2])\n",
        "  \n",
        "  # Every 100 steps, print the averages over the last 100 steps.\n",
        "  if (i+1) % 100 == 0:\n",
        "    print('%d iterations' % (i+1))\n",
        "    print('recon_error: %.3f' % np.mean(train_res_recon_error[-100:]))\n",
        "    print('perplexity: %.3f' % np.mean(train_res_perplexity[-100:]))\n",
        "    print()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "m2hNyAnhs-1f"
      },
      "source": [
        "# Plot loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "2vo-lDyRomKD"
      },
      "outputs": [],
      "source": [
        "# Two panels side by side, one point per training step in each.\n",
        "f = plt.figure(figsize=(16,8))\n",
        "# Left: normalised reconstruction error on a logarithmic y axis.\n",
        "ax = f.add_subplot(1,2,1)\n",
        "ax.plot(train_res_recon_error)\n",
        "ax.set_yscale('log')\n",
        "ax.set_title('NMSE.')\n",
        "\n",
        "# Right: perplexity of the codebook usage.\n",
        "ax = f.add_subplot(1,2,2)\n",
        "ax.plot(train_res_perplexity)\n",
        "ax.set_title('Average codebook usage (perplexity).')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "Lyj1CCKptCZz"
      },
      "source": [
        "# View reconstructions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "rM9zj7ZiPZBG"
      },
      "outputs": [],
      "source": [
        "# Reconstructions\n",
        "# The validation iterator is initializable, so run its initializer first.\n",
        "sess.run(valid_dataset_iterator.initializer)\n",
        "# One batch per subset, reconstructed through the evaluation path\n",
        "# (x_recon_eval, built with is_training=False).\n",
        "train_originals = get_images(sess, subset='train')\n",
        "train_reconstructions = sess.run(x_recon_eval, feed_dict={x: train_originals})\n",
        "valid_originals = get_images(sess, subset='valid')\n",
        "valid_reconstructions = sess.run(x_recon_eval, feed_dict={x: valid_originals})\n",
        "\n",
        "def convert_batch_to_image_grid(image_batch):\n",
        "  \"\"\"Tile a batch of 32 images of size 32x32x3 into one 4x8 image grid.\n",
        "\n",
        "  Args:\n",
        "    image_batch: array that can be reshaped to [4, 8, 32, 32, 3], holding\n",
        "      images that were shifted by -0.5 for training.\n",
        "\n",
        "  Returns:\n",
        "    Array of shape [128, 256, 3] with the shift undone and the values clipped\n",
        "    to [0, 1].\n",
        "  \"\"\"\n",
        "  reshaped = (image_batch.reshape(4, 8, 32, 32, 3)\n",
        "              .transpose(0, 2, 1, 3, 4)\n",
        "              .reshape(4 * 32, 8 * 32, 3))\n",
        "  # Reconstructions are unbounded network outputs, so after undoing the shift\n",
        "  # they can fall outside [0, 1], the valid range for float RGB data in\n",
        "  # imshow. Clip explicitly rather than relying on the plotting library.\n",
        "  return np.clip(reshaped + 0.5, 0.0, 1.0)\n",
        "\n",
        "\n",
        "\n",
        "# 2x2 figure: originals on the left, reconstructions on the right;\n",
        "# training batch in the top row, validation batch in the bottom row.\n",
        "f = plt.figure(figsize=(16,8))\n",
        "ax = f.add_subplot(2,2,1)\n",
        "ax.imshow(convert_batch_to_image_grid(train_originals),\n",
        "          interpolation='nearest')\n",
        "ax.set_title('training data originals')\n",
        "plt.axis('off')\n",
        "\n",
        "ax = f.add_subplot(2,2,2)\n",
        "ax.imshow(convert_batch_to_image_grid(train_reconstructions),\n",
        "          interpolation='nearest')\n",
        "ax.set_title('training data reconstructions')\n",
        "plt.axis('off')\n",
        "\n",
        "ax = f.add_subplot(2,2,3)\n",
        "ax.imshow(convert_batch_to_image_grid(valid_originals),\n",
        "          interpolation='nearest')\n",
        "ax.set_title('validation data originals')\n",
        "plt.axis('off')\n",
        "\n",
        "ax = f.add_subplot(2,2,4)\n",
        "ax.imshow(convert_batch_to_image_grid(valid_reconstructions),\n",
        "          interpolation='nearest')\n",
        "ax.set_title('validation data reconstructions')\n",
        "plt.axis('off')\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "default_view": {},
      "name": "VQ-VAE training example",
      "provenance": [],
      "version": "0.3.2",
      "views": {}
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
